# 中文对联 — Chinese couplet generation with uer/gpt2-chinese-couplet
import os

import torch
from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline

# Local Hugging Face cache directory that holds the downloaded model snapshot.
cache_dir = "../../my_model_cache/gpt2-chinese"
# Alternative: download (or reuse) the model through the cache by repo id.
# tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-couplet", cache_dir=cache_dir)
# model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-couplet", cache_dir=cache_dir)

# Directory of one pinned snapshot of uer/gpt2-chinese-couplet inside the
# cache (the `models--<org>--<name>/snapshots/<commit>` layout). Built once so
# the tokenizer and the model are guaranteed to come from the same snapshot.
model_dir = os.path.join(
    cache_dir,
    "models--uer--gpt2-chinese-couplet",
    "snapshots",
    "91b9465fb1be617f69c6f003b0bd6e6642537bec",
)
tokenizer = BertTokenizer.from_pretrained(model_dir)
model = GPT2LMHeadModel.from_pretrained(model_dir)

# Use the first GPU (device 0) when CUDA is available; otherwise run on the
# CPU (-1) instead of failing on machines without a GPU.
device = 0 if torch.cuda.is_available() else -1
text_generator = TextGenerationPipeline(model, tokenizer, device=device)

# Prompt format: "[CLS]" + first line of the couplet + " -"; the model is
# expected to continue with the matching second line. max_length bounds the
# total token count (prompt plus generated tokens); do_sample=True makes the
# output vary between runs.
out = text_generator("[CLS]十口心思，思乡思国思社稷 -", max_length=28, do_sample=True)
print(out)
